{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Splat Math\n",
    "\n",
    "### Camera Projection\n",
    "$$ \\left[\\begin{matrix} u \\\\ v \\end{matrix}\\right] = \\left[\\begin{matrix} f_x & 0 & c_x \\\\ 0 & f_y & c_y \\end{matrix}\\right] \\left[\\begin{matrix} \\frac{x}{z} \\\\ \\frac{y}{z} \\\\ 1 \\end{matrix}\\right] $$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd import Function, gradcheck\n",
    "\n",
    "\n",
    "class CameraProject(Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, x, y, z, fx, fy, cx, cy):\n",
    "        u = fx * x / z + cx\n",
    "        v = fy * y / z + cy\n",
    "        ctx.save_for_backward(x, y, z, fx, fy)\n",
    "        return u, v\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_u, grad_v):\n",
    "        x, y, z, fx, fy = ctx.saved_tensors\n",
    "        grad_x = grad_u * fx / z\n",
    "        grad_y = grad_v * fy / z\n",
    "        grad_z = -grad_u * fx * x / z**2 - grad_v * fy * y / z**2\n",
    "        return grad_x, grad_y, grad_z, None, None, None, None\n",
    "\n",
    "\n",
    "# Camera-frame point: these are the inputs gradcheck differentiates.\n",
    "x, y, z = (\n",
    "    torch.tensor(value, dtype=torch.float64, requires_grad=True) for value in (10.0, -5.0, 10.0)\n",
    ")\n",
    "# Intrinsics are created with requires_grad=False, i.e. held fixed for this check.\n",
    "fx, fy, cx, cy = (\n",
    "    torch.tensor(value, dtype=torch.float64, requires_grad=False)\n",
    "    for value in (1300.0, 1200.0, 320.0, 240.0)\n",
    ")\n",
    "\n",
    "test = gradcheck(CameraProject.apply, (x, y, z, fx, fy, cx, cy))\n",
    "print(test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Matrix Multiplication\n",
    "\n",
    "The reverse mode differentiation for matrix operations is documented in: [An extended collection of matrix derivative results for forward and reverse mode algorithmic differentiation](https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf) by Mike Giles\n",
    "\n",
    "Matrix Multiplication is documented in Section 2.2.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd import Function, gradcheck\n",
    "\n",
    "\n",
    "class MatrixMultiplication(Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, A, B):\n",
    "        C = A @ B\n",
    "        ctx.save_for_backward(A, B)\n",
    "        return C\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_C):\n",
    "        (\n",
    "            A,\n",
    "            B,\n",
    "        ) = ctx.saved_tensors\n",
    "        grad_A = grad_C @ B.T\n",
    "        grad_B = A.T @ grad_C\n",
    "        return grad_A, grad_B\n",
    "\n",
    "\n",
    "# Two random 3x3 double-precision factors, drawn in the order R then S.\n",
    "R, S = (torch.rand(3, 3, dtype=torch.float64, requires_grad=True) for _ in range(2))\n",
    "\n",
    "test = gradcheck(MatrixMultiplication.apply, (R, S))\n",
    "print(test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Computing `RSSR` (the matrix product $R S S^T R^T$) is just two matrix multiplications:\n",
    "1. $RS = R \\cdot S$\n",
    "2. $RSSR = RS \\cdot (RS)^T = R S S^T R^T$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd import Function, gradcheck\n",
    "\n",
    "\n",
    "class RSSR(Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, R, S):\n",
    "        RS = R @ S\n",
    "        RSSR = RS @ RS.T\n",
    "        ctx.save_for_backward(R, S)\n",
    "        return RSSR\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_RSSR):\n",
    "        R, S = ctx.saved_tensors\n",
    "        RS = R @ S\n",
    "        grad_RS = grad_RSSR @ RS\n",
    "        grad_SR = RS.T @ grad_RSSR\n",
    "\n",
    "        grad_R = grad_RS @ S.T\n",
    "        grad_S = R.T @ grad_RS\n",
    "\n",
    "        grad_R_t = S @ grad_SR\n",
    "        grad_S_t = grad_SR @ R\n",
    "\n",
    "        return grad_R + grad_R_t.T, grad_S + grad_S_t.T\n",
    "\n",
    "\n",
    "# Random 3x3 double-precision inputs, drawn in the order R then S.\n",
    "R, S = (torch.rand(3, 3, dtype=torch.float64, requires_grad=True) for _ in range(2))\n",
    "\n",
    "test = gradcheck(RSSR.apply, (R, S))\n",
    "print(test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Computing $\\Sigma_{image}$ matches the \"first quadratic form\" in Section 2.3.2 of the Giles report linked above:\n",
    "\n",
    "$ C = B^T A B $ \n",
    "\n",
    "$ \\Sigma_{image} = JW \\Sigma_{world} (JW)^T $\n",
    "\n",
    "Where $ B = (JW)^T $ and $ A = \\Sigma_{world} $ \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd import Function, gradcheck\n",
    "\n",
    "\n",
    "class ComputeSigmaImage(Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, sigma_world, W, J):\n",
    "        JW = J @ W\n",
    "        sigma_image = JW @ sigma_world @ JW.T\n",
    "        ctx.save_for_backward(sigma_world, W, J)\n",
    "        return sigma_image\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_sigma_image):\n",
    "        sigma_world, W, J = ctx.saved_tensors\n",
    "        JW = J @ W\n",
    "\n",
    "        # using First Quadratic Form 2.3.2 from: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf\n",
    "        # for C = B_t @ A @ B\n",
    "        # grad_A = B @ grad_C @ B_t\n",
    "        # grad_B = A @ B @ grad_C_t + A_t @ B @ grad_C\n",
    "\n",
    "        # applying to our variables\n",
    "        # sigma_image = JW @ sigma_world @ JW.T\n",
    "        # C = sigma_image\n",
    "        # A = sigma_world\n",
    "        # B = JW_t\n",
    "\n",
    "        grad_sigma_world = JW.T @ grad_sigma_image @ JW\n",
    "        grad_JW_t = (\n",
    "            sigma_world @ JW.T @ grad_sigma_image.T + sigma_world.T @ JW.T @ grad_sigma_image\n",
    "        )\n",
    "\n",
    "        # compute gradient of JW_t using multiplication rules in 2.2.2\n",
    "        grad_W_t = grad_JW_t @ J\n",
    "        grad_J_t = W @ grad_JW_t\n",
    "\n",
    "        grad_W = grad_W_t.T\n",
    "        grad_J = grad_J_t.T\n",
    "\n",
    "        return grad_sigma_world, grad_W, grad_J\n",
    "\n",
    "\n",
    "# Random double-precision inputs: 3x3 sigma_world, 3x3 W and 2x3 J, drawn in that order.\n",
    "sigma_world, W, J = (\n",
    "    torch.rand(*shape, dtype=torch.float64, requires_grad=True)\n",
    "    for shape in ((3, 3), (3, 3), (2, 3))\n",
    ")\n",
    "test = gradcheck(ComputeSigmaImage.apply, (sigma_world, W, J))\n",
    "print(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\left[\\begin{matrix}- 2 y^{2} - 2 z^{2} + 1 & - 2 w z + 2 x y & 2 w y + 2 x z\\\\2 w z + 2 x y & - 2 x^{2} - 2 z^{2} + 1 & - 2 w x + 2 y z\\\\- 2 w y + 2 x z & 2 w x + 2 y z & - 2 x^{2} - 2 y^{2} + 1\\end{matrix}\\right]\n",
      "\\left[\\begin{matrix}- 4 z & - 2 w & 2 x\\\\2 w & - 4 z & 2 y\\\\2 x & 2 y & 0\\end{matrix}\\right]\n"
     ]
    }
   ],
   "source": [
    "import sympy as sp\n",
    "from sympy import print_latex\n",
    "\n",
    "\n",
    "def quaternion_to_rotation_Symbolic(q):\n",
    "    \"\"\"Build the 3x3 sympy rotation matrix for a quaternion q = (w, x, y, z).\n",
    "\n",
    "    The divisor below is fixed at 1, so the components are used as given and no\n",
    "    normalization is applied.\n",
    "    \"\"\"\n",
    "    divisor = 1\n",
    "    w, x, y, z = (q[idx] / divisor for idx in range(4))\n",
    "\n",
    "    # One list per matrix row.\n",
    "    top = [1 - 2 * y**2 - 2 * z**2, 2 * x * y - 2 * w * z, 2 * x * z + 2 * w * y]\n",
    "    middle = [2 * x * y + 2 * w * z, 1 - 2 * x**2 - 2 * z**2, 2 * y * z - 2 * w * x]\n",
    "    bottom = [2 * x * z - 2 * w * y, 2 * y * z + 2 * w * x, 1 - 2 * x**2 - 2 * y**2]\n",
    "\n",
    "    return sp.Matrix([top, middle, bottom])\n",
    "\n",
    "\n",
    "# Symbolic quaternion components, kept both as a list and as individual names.\n",
    "q = list(sp.symbols(\"w x y z\"))\n",
    "w, x, y, z = q\n",
    "\n",
    "rotation_matrix = quaternion_to_rotation_Symbolic(q)\n",
    "print_latex(rotation_matrix)\n",
    "# Entry-wise derivative with respect to z; swap the symbol to obtain the other partials.\n",
    "rotation_derivative = rotation_matrix.diff(z)\n",
    "print_latex(rotation_derivative)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Rotation matrix from Quaternion:\n",
    "\n",
    "$$\\left[\\begin{matrix}- 2 y^{2} - 2 z^{2} + 1 & - 2 w z + 2 x y & 2 w y + 2 x z\\\\2 w z + 2 x y & - 2 x^{2} - 2 z^{2} + 1 & - 2 w x + 2 y z\\\\- 2 w y + 2 x z & 2 w x + 2 y z & - 2 x^{2} - 2 y^{2} + 1\\end{matrix}\\right] $$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Jacobian of the rotation matrix $R$ with respect to each quaternion component (without normalization):\n",
    "$$  \\frac{\\partial R}{\\partial w} = \\left[\\begin{matrix}0 & - 2 z & 2 y\\\\2 z & 0 & - 2 x\\\\- 2 y & 2 x & 0\\end{matrix}\\right] $$\n",
    "\n",
    "$$  \\frac{\\partial R}{\\partial x} =  \\left[\\begin{matrix}0 & 2 y & 2 z\\\\2 y & - 4 x & - 2 w\\\\2 z & 2 w & - 4 x\\end{matrix}\\right] $$ \n",
    "\n",
    "$$  \\frac{\\partial R}{\\partial y} =  \\left[\\begin{matrix}- 4 y & 2 x & 2 w\\\\2 x & 0 & 2 z\\\\- 2 w & 2 z & - 4 y\\end{matrix}\\right] $$ \n",
    "\n",
    "$$  \\frac{\\partial R}{\\partial z} =  \\left[\\begin{matrix}- 4 z & - 2 w & 2 x\\\\2 w & - 4 z & 2 y\\\\2 x & 2 y & 0\\end{matrix}\\right] $$ "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd import Function, gradcheck\n",
    "\n",
    "\n",
    "class QuaternionToRotation(Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, q):\n",
    "        rot = [\n",
    "            1 - 2 * q[:, 2] ** 2 - 2 * q[:, 3] ** 2,\n",
    "            2 * q[:, 1] * q[:, 2] - 2 * q[:, 0] * q[:, 3],\n",
    "            2 * q[:, 3] * q[:, 1] + 2 * q[:, 0] * q[:, 2],\n",
    "            2 * q[:, 1] * q[:, 2] + 2 * q[:, 0] * q[:, 3],\n",
    "            1 - 2 * q[:, 1] ** 2 - 2 * q[:, 3] ** 2,\n",
    "            2 * q[:, 2] * q[:, 3] - 2 * q[:, 0] * q[:, 1],\n",
    "            2 * q[:, 3] * q[:, 1] - 2 * q[:, 0] * q[:, 2],\n",
    "            2 * q[:, 2] * q[:, 3] + 2 * q[:, 0] * q[:, 1],\n",
    "            1 - 2 * q[:, 1] ** 2 - 2 * q[:, 2] ** 2,\n",
    "        ]\n",
    "        rot = torch.stack(rot, dim=1).reshape(-1, 3, 3)\n",
    "        ctx.save_for_backward(q)\n",
    "        return rot\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_rot):\n",
    "        q = ctx.saved_tensors[0]\n",
    "\n",
    "        w = q[:, 0]\n",
    "        x = q[:, 1]\n",
    "        y = q[:, 2]\n",
    "        z = q[:, 3]\n",
    "\n",
    "        grad_qw = (\n",
    "            -2 * z * grad_rot[:, 0, 1]\n",
    "            + 2 * y * grad_rot[:, 0, 2]\n",
    "            + 2 * z * grad_rot[:, 1, 0]\n",
    "            - 2 * x * grad_rot[:, 1, 2]\n",
    "            - 2 * y * grad_rot[:, 2, 0]\n",
    "            + 2 * x * grad_rot[:, 2, 1]\n",
    "        )\n",
    "        grad_qx = (\n",
    "            2 * y * grad_rot[:, 0, 1]\n",
    "            + 2 * z * grad_rot[:, 0, 2]\n",
    "            + 2 * y * grad_rot[:, 1, 0]\n",
    "            - 4 * x * grad_rot[:, 1, 1]\n",
    "            - 2 * w * grad_rot[:, 1, 2]\n",
    "            + 2 * z * grad_rot[:, 2, 0]\n",
    "            + 2 * w * grad_rot[:, 2, 1]\n",
    "            - 4 * x * grad_rot[:, 2, 2]\n",
    "        )\n",
    "        grad_qy = (\n",
    "            -4 * y * grad_rot[:, 0, 0]\n",
    "            + 2 * x * grad_rot[:, 0, 1]\n",
    "            + 2 * w * grad_rot[:, 0, 2]\n",
    "            + 2 * x * grad_rot[:, 1, 0]\n",
    "            + 2 * z * grad_rot[:, 1, 2]\n",
    "            - 2 * w * grad_rot[:, 2, 0]\n",
    "            + 2 * z * grad_rot[:, 2, 1]\n",
    "            - 4 * y * grad_rot[:, 2, 2]\n",
    "        )\n",
    "        grad_qz = (\n",
    "            -4 * z * grad_rot[:, 0, 0]\n",
    "            - 2 * w * grad_rot[:, 0, 1]\n",
    "            + 2 * x * grad_rot[:, 0, 2]\n",
    "            + 2 * w * grad_rot[:, 1, 0]\n",
    "            - 4 * z * grad_rot[:, 1, 1]\n",
    "            + 2 * y * grad_rot[:, 1, 2]\n",
    "            + 2 * x * grad_rot[:, 2, 0]\n",
    "            + 2 * y * grad_rot[:, 2, 1]\n",
    "        )\n",
    "        grad_q = torch.stack([grad_qw, grad_qx, grad_qy, grad_qz], dim=1)\n",
    "\n",
    "        return grad_q\n",
    "\n",
    "\n",
    "# Ten random quaternions, each row scaled to unit length before the check.\n",
    "q = torch.rand(10, 4, dtype=torch.float64, requires_grad=True)\n",
    "norm_q = q.norm(dim=1, keepdim=True)\n",
    "q = q / norm_q\n",
    "\n",
    "test = gradcheck(QuaternionToRotation.apply, (q,))\n",
    "print(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\left[\\begin{matrix}- \\frac{w^{2}}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}} + \\frac{1}{\\sqrt{w^{2} + x^{2} + y^{2} + z^{2}}}\\\\- \\frac{w x}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\\\- \\frac{w y}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\\\- \\frac{w z}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\end{matrix}\\right]\n",
      "\\left[\\begin{matrix}- \\frac{w x}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\\\- \\frac{x^{2}}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}} + \\frac{1}{\\sqrt{w^{2} + x^{2} + y^{2} + z^{2}}}\\\\- \\frac{x y}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\\\- \\frac{x z}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\end{matrix}\\right]\n",
      "\\left[\\begin{matrix}- \\frac{w y}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\\\- \\frac{x y}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\\\- \\frac{y^{2}}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}} + \\frac{1}{\\sqrt{w^{2} + x^{2} + y^{2} + z^{2}}}\\\\- \\frac{y z}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\end{matrix}\\right]\n",
      "\\left[\\begin{matrix}- \\frac{w z}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\\\- \\frac{x z}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\\\- \\frac{y z}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\\\- \\frac{z^{2}}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}} + \\frac{1}{\\sqrt{w^{2} + x^{2} + y^{2} + z^{2}}}\\end{matrix}\\right]\n"
     ]
    }
   ],
   "source": [
    "import sympy as sp\n",
    "from sympy import print_latex\n",
    "\n",
    "w, x, y, z = sp.symbols(\"w x y z\")\n",
    "q = sp.Matrix([w, x, y, z])\n",
    "norm = sp.sqrt(sum(component**2 for component in (w, x, y, z)))\n",
    "\n",
    "q_norm = q / norm\n",
    "\n",
    "# One column vector of partial derivatives per quaternion component.\n",
    "dw, dx, dy, dz = (q_norm.diff(symbol) for symbol in (w, x, y, z))\n",
    "\n",
    "for _partial in (dw, dx, dy, dz):\n",
    "    print_latex(_partial)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
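    "Partial derivatives of the normalized quaternion $\\hat{q} = q / \\lVert q \\rVert$ with respect to each component of the unnormalized quaternion (the left-hand sides below write $\\partial q$ for $\\partial \\hat{q}$):\n",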
    "\n",
    "$$ \\frac{\\partial q}{\\partial w} = \\left[\\begin{matrix}- \\frac{w^{2}}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}} + \\frac{1}{\\sqrt{w^{2} + x^{2} + y^{2} + z^{2}}}\\\\- \\frac{w x}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\\\- \\frac{w y}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\\\- \\frac{w z}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\end{matrix}\\right] $$\n",
    "\n",
    "$$ \\frac{\\partial q}{\\partial x} = \\left[\\begin{matrix}- \\frac{w x}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\\\- \\frac{x^{2}}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}} + \\frac{1}{\\sqrt{w^{2} + x^{2} + y^{2} + z^{2}}}\\\\- \\frac{x y}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\\\- \\frac{x z}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\end{matrix}\\right] $$\n",
    "\n",
    "$$ \\frac{\\partial q}{\\partial y} =\\left[\\begin{matrix}- \\frac{w y}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\\\- \\frac{x y}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\\\- \\frac{y^{2}}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}} + \\frac{1}{\\sqrt{w^{2} + x^{2} + y^{2} + z^{2}}}\\\\- \\frac{y z}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\end{matrix}\\right] $$ \n",
    "\n",
    "$$ \\frac{\\partial q}{\\partial z} = \\left[\\begin{matrix}- \\frac{w z}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\\\- \\frac{x z}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\\\- \\frac{y z}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}}\\\\- \\frac{z^{2}}{\\left(w^{2} + x^{2} + y^{2} + z^{2}\\right)^{\\frac{3}{2}}} + \\frac{1}{\\sqrt{w^{2} + x^{2} + y^{2} + z^{2}}}\\end{matrix}\\right] $$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd import Function, gradcheck\n",
    "\n",
    "\n",
    "class QuaternionNormalization(Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, q):\n",
    "        q_norm = q / torch.norm(q, dim=1, keepdim=True)\n",
    "        ctx.save_for_backward(q)\n",
    "        return q_norm\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_q_norm):\n",
    "        q = ctx.saved_tensors[0]\n",
    "        w = q[:, 0]\n",
    "        x = q[:, 1]\n",
    "        y = q[:, 2]\n",
    "        z = q[:, 3]\n",
    "\n",
    "        norm_sq = w * w + x * x + y * y + z * z\n",
    "        grad_qw = (\n",
    "            (-1 * w * w / norm_sq**1.5 + 1 / norm_sq**0.5) * grad_q_norm[:, 0]\n",
    "            - w * x / norm_sq**1.5 * grad_q_norm[:, 1]\n",
    "            - w * y / norm_sq**1.5 * grad_q_norm[:, 2]\n",
    "            - w * z / norm_sq**1.5 * grad_q_norm[:, 3]\n",
    "        )\n",
    "        grad_qx = (\n",
    "            -w * x / norm_sq**1.5 * grad_q_norm[:, 0]\n",
    "            + (-1 * x * x / norm_sq**1.5 + 1 / norm_sq**0.5) * grad_q_norm[:, 1]\n",
    "            - x * y / norm_sq**1.5 * grad_q_norm[:, 2]\n",
    "            - x * z / norm_sq**1.5 * grad_q_norm[:, 3]\n",
    "        )\n",
    "        grad_qy = (\n",
    "            -w * y / norm_sq**1.5 * grad_q_norm[:, 0]\n",
    "            - x * y / norm_sq**1.5 * grad_q_norm[:, 1]\n",
    "            + (-1 * y * y / norm_sq**1.5 + 1 / norm_sq**0.5) * grad_q_norm[:, 2]\n",
    "            - y * z / norm_sq**1.5 * grad_q_norm[:, 3]\n",
    "        )\n",
    "        grad_qz = (\n",
    "            -w * z / norm_sq**1.5 * grad_q_norm[:, 0]\n",
    "            - x * z / norm_sq**1.5 * grad_q_norm[:, 1]\n",
    "            - y * z / norm_sq**1.5 * grad_q_norm[:, 2]\n",
    "            + (-1 * z * z / norm_sq**1.5 + 1 / norm_sq**0.5) * grad_q_norm[:, 3]\n",
    "        )\n",
    "        grad_q = torch.stack([grad_qw, grad_qx, grad_qy, grad_qz], dim=1)\n",
    "\n",
    "        return grad_q\n",
    "\n",
    "\n",
    "# Two random, unnormalized quaternions in double precision.\n",
    "q = torch.rand((2, 4), dtype=torch.float64).requires_grad_()\n",
    "\n",
    "test = gradcheck(QuaternionNormalization.apply, (q,))\n",
    "print(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd import Function, gradcheck\n",
    "\n",
    "\n",
    "class ComputeSigmaWorld(Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, q, scale):\n",
    "        S = torch.diag_embed(torch.exp(scale))\n",
    "        norm_q = torch.norm(q, dim=1, keepdim=True)\n",
    "        q_norm = q / norm_q\n",
    "        R = [\n",
    "            1 - 2 * q_norm[:, 2] ** 2 - 2 * q_norm[:, 3] ** 2,\n",
    "            2 * q_norm[:, 1] * q_norm[:, 2] - 2 * q_norm[:, 0] * q_norm[:, 3],\n",
    "            2 * q_norm[:, 3] * q_norm[:, 1] + 2 * q_norm[:, 0] * q_norm[:, 2],\n",
    "            2 * q_norm[:, 1] * q_norm[:, 2] + 2 * q_norm[:, 0] * q_norm[:, 3],\n",
    "            1 - 2 * q_norm[:, 1] ** 2 - 2 * q_norm[:, 3] ** 2,\n",
    "            2 * q_norm[:, 2] * q_norm[:, 3] - 2 * q_norm[:, 0] * q_norm[:, 1],\n",
    "            2 * q_norm[:, 3] * q_norm[:, 1] - 2 * q_norm[:, 0] * q_norm[:, 2],\n",
    "            2 * q_norm[:, 2] * q_norm[:, 3] + 2 * q_norm[:, 0] * q_norm[:, 1],\n",
    "            1 - 2 * q_norm[:, 1] ** 2 - 2 * q_norm[:, 2] ** 2,\n",
    "        ]\n",
    "        R = torch.stack(R, dim=1).reshape(-1, 3, 3)\n",
    "\n",
    "        RS = torch.bmm(R, S)\n",
    "        RS_t = RS.permute(0, 2, 1)\n",
    "\n",
    "        RSSR = torch.bmm(RS, RS_t)\n",
    "        ctx.save_for_backward(RS, R, S, scale, q, q_norm)\n",
    "        return RSSR\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_RSSR):\n",
    "        # compute double matmul gradient\n",
    "        RS, R, S, scale, q, q_norm = ctx.saved_tensors\n",
    "        grad_RS = torch.bmm(grad_RSSR, RS)\n",
    "\n",
    "        RS_t = RS.permute(0, 2, 1)\n",
    "        grad_SR = RS_t @ grad_RSSR\n",
    "\n",
    "        grad_R = grad_RS @ S.permute(0, 2, 1) + (S @ grad_SR).permute(0, 2, 1)\n",
    "        grad_S = R.permute(0, 2, 1) @ grad_RS + (grad_SR @ R).permute(0, 2, 1)\n",
    "\n",
    "        # compute quaternion gradient\n",
    "        w = q_norm[:, 0]\n",
    "        x = q_norm[:, 1]\n",
    "        y = q_norm[:, 2]\n",
    "        z = q_norm[:, 3]\n",
    "        grad_qw_norm = (\n",
    "            -2 * z * grad_R[:, 0, 1]\n",
    "            + 2 * y * grad_R[:, 0, 2]\n",
    "            + 2 * z * grad_R[:, 1, 0]\n",
    "            - 2 * x * grad_R[:, 1, 2]\n",
    "            - 2 * y * grad_R[:, 2, 0]\n",
    "            + 2 * x * grad_R[:, 2, 1]\n",
    "        )\n",
    "        grad_qx_norm = (\n",
    "            2 * y * grad_R[:, 0, 1]\n",
    "            + 2 * z * grad_R[:, 0, 2]\n",
    "            + 2 * y * grad_R[:, 1, 0]\n",
    "            - 4 * x * grad_R[:, 1, 1]\n",
    "            - 2 * w * grad_R[:, 1, 2]\n",
    "            + 2 * z * grad_R[:, 2, 0]\n",
    "            + 2 * w * grad_R[:, 2, 1]\n",
    "            - 4 * x * grad_R[:, 2, 2]\n",
    "        )\n",
    "        grad_qy_norm = (\n",
    "            -4 * y * grad_R[:, 0, 0]\n",
    "            + 2 * x * grad_R[:, 0, 1]\n",
    "            + 2 * w * grad_R[:, 0, 2]\n",
    "            + 2 * x * grad_R[:, 1, 0]\n",
    "            + 2 * z * grad_R[:, 1, 2]\n",
    "            - 2 * w * grad_R[:, 2, 0]\n",
    "            + 2 * z * grad_R[:, 2, 1]\n",
    "            - 4 * y * grad_R[:, 2, 2]\n",
    "        )\n",
    "        grad_qz_norm = (\n",
    "            -4 * z * grad_R[:, 0, 0]\n",
    "            - 2 * w * grad_R[:, 0, 1]\n",
    "            + 2 * x * grad_R[:, 0, 2]\n",
    "            + 2 * w * grad_R[:, 1, 0]\n",
    "            - 4 * z * grad_R[:, 1, 1]\n",
    "            + 2 * y * grad_R[:, 1, 2]\n",
    "            + 2 * x * grad_R[:, 2, 0]\n",
    "            + 2 * y * grad_R[:, 2, 1]\n",
    "        )\n",
    "        grad_q_norm = torch.stack([grad_qw_norm, grad_qx_norm, grad_qy_norm, grad_qz_norm], dim=1)\n",
    "\n",
    "        # compute gradient for unnormalized quaternion\n",
    "        w = q[:, 0]\n",
    "        x = q[:, 1]\n",
    "        y = q[:, 2]\n",
    "        z = q[:, 3]\n",
    "        norm_sq = w * w + x * x + y * y + z * z\n",
    "        grad_qw = (\n",
    "            (-1 * w * w / norm_sq**1.5 + 1 / norm_sq**0.5) * grad_q_norm[:, 0]\n",
    "            - w * x / norm_sq**1.5 * grad_q_norm[:, 1]\n",
    "            - w * y / norm_sq**1.5 * grad_q_norm[:, 2]\n",
    "            - w * z / norm_sq**1.5 * grad_q_norm[:, 3]\n",
    "        )\n",
    "        grad_qx = (\n",
    "            -w * x / norm_sq**1.5 * grad_q_norm[:, 0]\n",
    "            + (-1 * x * x / norm_sq**1.5 + 1 / norm_sq**0.5) * grad_q_norm[:, 1]\n",
    "            - x * y / norm_sq**1.5 * grad_q_norm[:, 2]\n",
    "            - x * z / norm_sq**1.5 * grad_q_norm[:, 3]\n",
    "        )\n",
    "        grad_qy = (\n",
    "            -w * y / norm_sq**1.5 * grad_q_norm[:, 0]\n",
    "            - x * y / norm_sq**1.5 * grad_q_norm[:, 1]\n",
    "            + (-1 * y * y / norm_sq**1.5 + 1 / norm_sq**0.5) * grad_q_norm[:, 2]\n",
    "            - y * z / norm_sq**1.5 * grad_q_norm[:, 3]\n",
    "        )\n",
    "        grad_qz = (\n",
    "            -w * z / norm_sq**1.5 * grad_q_norm[:, 0]\n",
    "            - x * z / norm_sq**1.5 * grad_q_norm[:, 1]\n",
    "            - y * z / norm_sq**1.5 * grad_q_norm[:, 2]\n",
    "            + (-1 * z * z / norm_sq**1.5 + 1 / norm_sq**0.5) * grad_q_norm[:, 3]\n",
    "        )\n",
    "        grad_q = torch.stack([grad_qw, grad_qx, grad_qy, grad_qz], dim=1)\n",
    "\n",
    "        grad_scale_no_activation = grad_S.diagonal(dim1=1, dim2=2)\n",
    "        grad_scale = grad_scale_no_activation * torch.exp(scale)\n",
    "\n",
    "        return grad_q, grad_scale\n",
    "\n",
    "\n",
    "N = 2\n",
    "# N random quaternions (4 columns) and N random scale vectors (3 columns), drawn in that order.\n",
    "q, s = (\n",
    "    torch.rand(N, width, dtype=torch.float64, requires_grad=True) for width in (4, 3)\n",
    ")\n",
    "\n",
    "test = gradcheck(ComputeSigmaWorld.apply, (q, s))\n",
    "print(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\frac{a v^{2} - 2 b u v + c u^{2}}{a c - 2 b}\n",
      "- \\frac{a \\left(a v^{2} - 2 b u v + c u^{2}\\right)}{\\left(a c - 2 b\\right)^{2}} + \\frac{u^{2}}{a c - 2 b}\n"
     ]
    }
   ],
   "source": [
    "import sympy as sp\n",
    "from sympy import print_latex, exp\n",
    "\n",
    "\n",
    "a, b, c = sp.symbols(\"a b c\")\n",
    "# Symmetric 2x2 matrix: b is the shared off-diagonal entry.\n",
    "sigma_image = sp.Matrix([[a, b], [b, c]])\n",
    "\n",
    "u, v = sp.symbols(\"u v\")\n",
    "\n",
    "# Squared Mahalanobis distance [u v] @ sigma_image^-1 @ [u v]^T, written out with the 2x2\n",
    "# adjugate. The denominator is the determinant a*c - b**2 (b squared, not 2*b), taken\n",
    "# from the matrix itself so it cannot drift from the definition above.\n",
    "mh_dist_sq = (c * u**2 - b * u * v - b * u * v + a * v**2) / sigma_image.det()\n",
    "\n",
    "print_latex(mh_dist_sq)\n",
    "print_latex(sp.diff(mh_dist_sq, c))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$ d_m^2 = \\frac{a v^{2} - b u v - c u v + d u^{2}}{a d - b c} $$ \n",
    "\n",
    "$$ \\frac{\\partial d_m^2}{\\partial a} = - \\frac{d \\left(a v^{2} - b u v - c u v + d u^{2}\\right)}{\\left(a d - b c\\right)^{2}} + \\frac{v^{2}}{a d - b c} $$ \n",
    "\n",
    "$$ \\frac{\\partial d_m^2}{\\partial b} = \\frac{c \\left(a v^{2} - b u v - c u v + d u^{2}\\right)}{\\left(a d - b c\\right)^{2}} - \\frac{u v}{a d - b c} $$ \n",
    "\n",
    "$$ \\frac{\\partial d_m^2}{\\partial c} = \\frac{b \\left(a v^{2} - b u v - c u v + d u^{2}\\right)}{\\left(a d - b c\\right)^{2}} - \\frac{u v}{a d - b c} $$ \n",
    "\n",
    "$$ \\frac{\\partial d_m^2}{\\partial d} = - \\frac{a \\left(a v^{2} - b u v - c u v + d u^{2}\\right)}{\\left(a d - b c\\right)^{2}} + \\frac{u^{2}}{a d - b c} $$\n",
    "\n",
    "\n",
    "$$ \\frac{\\partial d_m^2}{\\partial u} = \\frac{- b v - c v + 2 d u}{a d - b c} $$ \n",
    "\n",
    "$$ \\frac{\\partial d_m^2}{\\partial v} = \\frac{2 a v - b u - c u}{a d - b c} $$ \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd import Function, gradcheck\n",
    "\n",
    "\n",
    "class ComputeAlpha(Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, sigma_image, opa, uv_splat, uv_pixel):\n",
    "        uv_diff = uv_pixel - uv_splat\n",
    "        a = sigma_image[0, 0]\n",
    "        b = sigma_image[0, 1]\n",
    "        c = sigma_image[1, 0]\n",
    "        d = sigma_image[1, 1]\n",
    "        mh_dist = (\n",
    "            d * uv_diff[0] ** 2\n",
    "            - b * uv_diff[0] * uv_diff[1]\n",
    "            - c * uv_diff[0] * uv_diff[1]\n",
    "            + a * uv_diff[1] ** 2\n",
    "        ) / (a * d - b * c)\n",
    "\n",
    "        prob = torch.exp(-0.5 * mh_dist)\n",
    "        alpha = prob * opa\n",
    "        ctx.save_for_backward(prob, sigma_image, uv_diff, opa)\n",
    "        return alpha\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_alpha):\n",
    "        prob, sigma_image, uv_diff, opa = ctx.saved_tensors\n",
    "        grad_opa = prob * grad_alpha\n",
    "\n",
    "        ## compute sigma world and uv_diff gradients\n",
    "        grad_prob = opa * grad_alpha\n",
    "        grad_mh = -0.5 * prob * grad_prob\n",
    "\n",
    "        a = sigma_image[0, 0]\n",
    "        b = sigma_image[0, 1]\n",
    "        c = sigma_image[1, 0]\n",
    "        d = sigma_image[1, 1]\n",
    "\n",
    "        u_diff = uv_diff[0]\n",
    "        v_diff = uv_diff[1]\n",
    "\n",
    "        grad_u = -(-b * v_diff - c * v_diff + 2 * d * u_diff) / (a * d - b * c) * grad_mh\n",
    "        grad_v = -(2 * a * v_diff - b * u_diff - c * u_diff) / (a * d - b * c) * grad_mh\n",
    "\n",
    "        grad_a = (\n",
    "            -d\n",
    "            * (a * v_diff**2 - b * u_diff * v_diff - c * u_diff * v_diff + d * u_diff**2)\n",
    "            / (a * d - b * c) ** 2\n",
    "            + (v_diff**2) / (a * d - b * c)\n",
    "        ) * grad_mh\n",
    "        grad_b = (\n",
    "            c\n",
    "            * (a * v_diff**2 - b * u_diff * v_diff - c * u_diff * v_diff + d * u_diff**2)\n",
    "            / (a * d - b * c) ** 2\n",
    "            - (u_diff * v_diff) / (a * d - b * c)\n",
    "        ) * grad_mh\n",
    "        grad_c = (\n",
    "            b\n",
    "            * (a * v_diff**2 - b * u_diff * v_diff - c * u_diff * v_diff + d * u_diff**2)\n",
    "            / (a * d - b * c) ** 2\n",
    "            - (u_diff * v_diff) / (a * d - b * c)\n",
    "        ) * grad_mh\n",
    "        grad_d = (\n",
    "            -a\n",
    "            * (a * v_diff**2 - b * u_diff * v_diff - c * u_diff * v_diff + d * u_diff**2)\n",
    "            / (a * d - b * c) ** 2\n",
    "            + (u_diff**2) / (a * d - b * c)\n",
    "        ) * grad_mh\n",
    "\n",
    "        grad_sigma_image = torch.Tensor([[grad_a, grad_b], [grad_c, grad_d]])\n",
    "        grad_uv_splat = torch.Tensor([grad_u, grad_v])\n",
    "\n",
    "        return grad_sigma_image, grad_opa, grad_uv_splat, None\n",
    "\n",
    "\n",
    "uv_splat = torch.rand(2, dtype=torch.float64, requires_grad=True)\n",
    "uv_pixel = torch.rand(2, dtype=torch.float64, requires_grad=False)\n",
    "\n",
    "sigma_image = torch.rand(2, 2, dtype=torch.float64, requires_grad=True)\n",
    "opa = torch.rand(1, dtype=torch.float64, requires_grad=True)\n",
    "test = gradcheck(ComputeAlpha.apply, (sigma_image, opa, uv_splat, uv_pixel))\n",
    "print(test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Alpha Compositing\n",
    "\n",
    "First (front-to-back) Gaussian Splatted to Pixel: \n",
    "$$ \\alpha_a = 0.0 $$ \n",
    "$$ \\alpha_c = \\alpha_0(1.0 - \\alpha_a) = \\alpha_0 $$ \n",
    "\n",
    "Second splat:\n",
    "$$ \\alpha_a = \\alpha_0 $$\n",
    "$$ \\alpha_c = \\alpha_1(1.0 - \\alpha_a) = \\alpha_1(1.0 - \\alpha_0) $$ \n",
    "\n",
    "Third splat:\n",
    "$$ \\alpha_a = \\alpha_0 + \\alpha_1(1.0 - \\alpha_0) $$\n",
    "$$ \\alpha_c = \\alpha_2(1.0 - \\alpha_a) = \\alpha_2(1.0 - (\\alpha_0 + \\alpha_1(1.0 - \\alpha_0))) $$ \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd import Function, gradcheck\n",
    "\n",
    "\n",
    "class AlphaComposite(Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, colors, alphas):\n",
    "        alpha_accum = 0.0\n",
    "        color_accum = torch.zeros_like(colors[0])\n",
    "        for i in range(alphas.shape[0]):\n",
    "            alpha_weight = 1 - alpha_accum\n",
    "            alpha_current = alphas[i] * (1 - alpha_accum)\n",
    "            color_accum += alpha_current * colors[i, :]\n",
    "            alpha_accum += alpha_current\n",
    "\n",
    "        ctx.save_for_backward(alpha_weight, alphas, colors)\n",
    "        return color_accum\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_color_accum):\n",
    "        weight_final, alphas, colors = ctx.saved_tensors\n",
    "        grad_alphas = torch.zeros_like(alphas)\n",
    "        grad_colors = torch.zeros_like(colors)\n",
    "\n",
    "        colors_accum = torch.zeros_like(colors[0])\n",
    "        weight = weight_final\n",
    "        for i in reversed(range(alphas.shape[0])):\n",
    "            grad_colors[i] = alphas[i] * weight * grad_color_accum\n",
    "            grad_alphas[i] = torch.dot(\n",
    "                (colors[i, :] * weight - colors_accum / (1.0 - alphas[i])), grad_color_accum\n",
    "            )\n",
    "\n",
    "            colors_accum = colors_accum + alphas[i] * colors[i, :] * weight\n",
    "            weight = weight / (1 - alphas[i - 1])\n",
    "\n",
    "        return grad_colors, grad_alphas\n",
    "\n",
    "\n",
    "alphas = torch.rand(10, dtype=torch.float64, requires_grad=True) / 10.0\n",
    "colors = torch.rand(10, 3, dtype=torch.float64, requires_grad=True)\n",
    "\n",
    "test = gradcheck(AlphaComposite.apply, (colors, alphas))\n",
    "print(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd import Function, gradcheck\n",
    "\n",
    "\n",
    "class AlphaCompositeWithBackground(Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, colors, alphas, background_rgb):\n",
    "        alpha_accum = 0.0\n",
    "        color_accum = torch.zeros_like(colors[0])\n",
    "        for i in range(alphas.shape[0]):\n",
    "            alpha_weight = 1 - alpha_accum\n",
    "            alpha_current = alphas[i] * (1 - alpha_accum)\n",
    "            color_accum += alpha_current * colors[i, :]\n",
    "            alpha_accum += alpha_current\n",
    "\n",
    "        color_accum += (1.0 - alpha_accum) * background_rgb\n",
    "\n",
    "        ctx.save_for_backward(alpha_weight, alphas, colors, background_rgb)\n",
    "        return color_accum\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_color_accum):\n",
    "        weight_final, alphas, colors, background_rgb = ctx.saved_tensors\n",
    "        grad_alphas = torch.zeros_like(alphas)\n",
    "        grad_colors = torch.zeros_like(colors)\n",
    "\n",
    "        colors_accum = torch.zeros_like(colors[0])\n",
    "        weight = weight_final\n",
    "\n",
    "        final_alpha_current = alphas[-1] * weight\n",
    "        alpha_current_prev = 1.0 - weight\n",
    "\n",
    "        bg_weight = 1.0 - (final_alpha_current + alpha_current_prev)\n",
    "\n",
    "        colors_accum += bg_weight * background_rgb\n",
    "        for i in reversed(range(alphas.shape[0])):\n",
    "            grad_colors[i] = alphas[i] * weight * grad_color_accum\n",
    "            grad_alphas[i] = torch.dot(\n",
    "                (colors[i, :] * weight - colors_accum / (1.0 - alphas[i])), grad_color_accum\n",
    "            )\n",
    "\n",
    "            colors_accum = colors_accum + alphas[i] * colors[i, :] * weight\n",
    "            weight = weight / (1 - alphas[i - 1])\n",
    "\n",
    "        return grad_colors, grad_alphas, None\n",
    "\n",
    "\n",
    "alphas = torch.rand(10, dtype=torch.float64, requires_grad=True) / 10.0\n",
    "colors = torch.rand(10, 3, dtype=torch.float64, requires_grad=True)\n",
    "background_rgb = torch.ones(3, dtype=torch.float64, requires_grad=False)\n",
    "\n",
    "test = gradcheck(AlphaCompositeWithBackground.apply, (colors, alphas, background_rgb))\n",
    "print(test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Spherical Harmonics\n",
    "\n",
    "\n",
    "Zero Order - not dependent on viewing direction\n",
    "\n",
    "$ Y_0 = \\frac{1}{2} \\sqrt{\\frac{1}{\\pi}}$\n",
    "\n",
    "First order\n",
    "\n",
    "$ Y_1^{-1} = \\frac{1}{2} \\sqrt{\\frac{3}{2\\pi}} \\frac{y}{r}$\n",
    "\n",
    "$ Y_1^{0} = \\frac{1}{2} \\sqrt{\\frac{3}{2\\pi}} \\frac{z}{r}$\n",
    "\n",
    "$ Y_1^{1} = \\frac{1}{2} \\sqrt{\\frac{3}{2\\pi}} \\frac{x}{r}$\n",
    "\n",
    "\n",
    "Second Order\n",
    "\n",
    "$ Y_2^{-2} = \\frac{1}{2} \\sqrt{\\frac{15}{\\pi}} \\frac{xy}{r^2}$\n",
    "\n",
    "$ Y_2^{-1} = \\frac{1}{2} \\sqrt{\\frac{15}{\\pi}} \\frac{yz}{r^2}$\n",
    "\n",
    "$ Y_2^{0} = \\frac{1}{4} \\sqrt{\\frac{5}{\\pi}} \\frac{3z^2 - r^2}{r^2}$\n",
    "\n",
    "$ Y_2^{1} = \\frac{1}{2} \\sqrt{\\frac{15}{\\pi}} \\frac{xz}{r^2}$\n",
    "\n",
    "$ Y_2^{2} = \\frac{1}{4} \\sqrt{\\frac{15}{\\pi}} \\frac{x^2 - y^2}{r^2}$\n",
    "\n",
    "\n",
    "Third Order\n",
    "\n",
    "$ Y_3^{-3} = \\frac{1}{4} \\sqrt{\\frac{35}{2\\pi}} \\frac{y(3x^2-y^2)}{r^3}$\n",
    "\n",
    "$ Y_3^{-2} = \\frac{1}{2} \\sqrt{\\frac{105}{\\pi}} \\frac{xyz}{r^3}$\n",
    "\n",
    "$ Y_3^{-1} = \\frac{1}{4} \\sqrt{\\frac{21}{2\\pi}} \\frac{y(5z^2-r^2)}{r^3}$\n",
    "\n",
    "$ Y_3^{0} = \\frac{1}{4} \\sqrt{\\frac{7}{\\pi}} \\frac{z(5z^2-3r^2)}{r^3}$\n",
    "\n",
    "$ Y_3^{1} = \\frac{1}{4} \\sqrt{\\frac{21}{2\\pi}} \\frac{x(5z^2-r^2)}{r^3}$\n",
    "\n",
    "$ Y_3^{2} = \\frac{1}{4} \\sqrt{\\frac{105}{\\pi}} \\frac{z(x^2-y^2)}{r^3}$\n",
    "\n",
    "$ Y_3^{3} = \\frac{1}{4} \\sqrt{\\frac{35}{2\\pi}} \\frac{x(x^2-3y^2)}{r^3}$\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Band 0 0.28209479177387814\n",
      "Band 1 0.4886025119029199\n",
      "Band 2 1.0925484305920792 -1.0925484305920792 0.31539156525252005 -1.0925484305920792 0.5462742152960396\n",
      "Band 3 0.5900435899266435 2.890611442640554 0.4570457994644658 0.263875515352797 0.4570457994644658 1.445305721320277 0.5900435899266435\n"
     ]
    }
   ],
   "source": [
    "# compute the constants for the SH bands\n",
    "import math\n",
    "\n",
    "C_SH_0 = math.sqrt(1 / math.pi) / 2\n",
    "C_SH_1 = math.sqrt(3 / math.pi) / 2\n",
    "\n",
    "C_SH_2_N2 = math.sqrt(15 / math.pi) / 2\n",
    "C_SH_2_N1 = -1.0 * math.sqrt(15 / math.pi) / 2\n",
    "C_SH_2_0 = math.sqrt(5 / math.pi) / 4\n",
    "C_SH_2_P1 = -1.0 * math.sqrt(15 / math.pi) / 2\n",
    "C_SH_2_P2 = math.sqrt(15 / math.pi) / 4\n",
    "\n",
    "C_SH_3_N3 = math.sqrt(35 / (2 * math.pi)) / 4\n",
    "C_SH_3_N2 = math.sqrt(105 / (math.pi)) / 2\n",
    "C_SH_3_N1 = math.sqrt(21 / (2 * math.pi)) / 4\n",
    "C_SH_3_0 = math.sqrt(7 / (2 * math.pi)) / 4\n",
    "C_SH_3_P1 = math.sqrt(21 / (2 * math.pi)) / 4\n",
    "C_SH_3_P2 = math.sqrt(105 / (math.pi)) / 4\n",
    "C_SH_3_P3 = math.sqrt(35 / (2 * math.pi)) / 4\n",
    "\n",
    "\n",
    "print(\"Band 0\", C_SH_0)\n",
    "print(\"Band 1\", C_SH_1)\n",
    "print(\"Band 2\", C_SH_2_N2, C_SH_2_N1, C_SH_2_0, C_SH_2_P1, C_SH_2_P2)\n",
    "print(\"Band 3\", C_SH_3_N3, C_SH_3_N2, C_SH_3_N1, C_SH_3_0, C_SH_3_P1, C_SH_3_P2, C_SH_3_P3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd import Function, gradcheck\n",
    "\n",
    "\n",
    "class SphericalHarmonicsToRGB(Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, sh_coeff, view_dir):\n",
    "        view_dir = view_dir / torch.norm(view_dir, dim=1, keepdim=True)\n",
    "        x = view_dir[:, 0]\n",
    "        y = view_dir[:, 1]\n",
    "        z = view_dir[:, 2]\n",
    "\n",
    "        N = sh_coeff.shape[0]\n",
    "        N_sh = sh_coeff.shape[2]\n",
    "\n",
    "        sh_r = sh_coeff[:, 0, :]\n",
    "        sh_g = sh_coeff[:, 1, :]\n",
    "        sh_b = sh_coeff[:, 2, :]\n",
    "\n",
    "        rgb = torch.zeros(N, 3, dtype=sh_coeff.dtype, device=sh_coeff.device)\n",
    "        rgb[:, 0] = C_SH_0 * sh_r[:, 0]\n",
    "        rgb[:, 1] = C_SH_0 * sh_g[:, 0]\n",
    "        rgb[:, 2] = C_SH_0 * sh_b[:, 0]\n",
    "\n",
    "        if N_sh >= 4:\n",
    "            rgb[:, 0] += (\n",
    "                -1 * C_SH_1 * sh_r[:, 1] * x + C_SH_1 * sh_r[:, 2] * y - C_SH_1 * sh_r[:, 3] * z\n",
    "            )\n",
    "            rgb[:, 1] += (\n",
    "                -1 * C_SH_1 * sh_g[:, 1] * x + C_SH_1 * sh_g[:, 2] * y - C_SH_1 * sh_g[:, 3] * z\n",
    "            )\n",
    "            rgb[:, 2] += (\n",
    "                -1 * C_SH_1 * sh_b[:, 1] * x + C_SH_1 * sh_b[:, 2] * y - C_SH_1 * sh_b[:, 3] * z\n",
    "            )\n",
    "        if N_sh >= 9:\n",
    "            xx = torch.square(x)\n",
    "            yy = torch.square(y)\n",
    "            zz = torch.square(z)\n",
    "            rgb[:, 0] += (\n",
    "                C_SH_2_N2 * sh_r[:, 4] * x * y\n",
    "                + C_SH_2_N1 * sh_r[:, 5] * y * z\n",
    "                + C_SH_2_0\n",
    "                * sh_r[:, 6]\n",
    "                * (3 * zz - torch.ones(N, dtype=sh_coeff.dtype, device=sh_coeff.device))\n",
    "                + C_SH_2_P1 * sh_r[:, 7] * x * z\n",
    "                + C_SH_2_P2 * sh_r[:, 8] * (xx - yy)\n",
    "            )\n",
    "\n",
    "            rgb[:, 1] += (\n",
    "                C_SH_2_N2 * sh_g[:, 4] * x * y\n",
    "                + C_SH_2_N1 * sh_g[:, 5] * y * z\n",
    "                + C_SH_2_0\n",
    "                * sh_g[:, 6]\n",
    "                * (3 * zz - torch.ones(N, dtype=sh_coeff.dtype, device=sh_coeff.device))\n",
    "                + C_SH_2_P1 * sh_g[:, 7] * x * z\n",
    "                + C_SH_2_P2 * sh_g[:, 8] * (xx - yy)\n",
    "            )\n",
    "\n",
    "            rgb[:, 2] += (\n",
    "                C_SH_2_N2 * sh_b[:, 4] * x * y\n",
    "                + C_SH_2_N1 * sh_b[:, 5] * y * z\n",
    "                + C_SH_2_0\n",
    "                * sh_b[:, 6]\n",
    "                * (3 * zz - torch.ones(N, dtype=sh_coeff.dtype, device=sh_coeff.device))\n",
    "                + C_SH_2_P1 * sh_b[:, 7] * x * z\n",
    "                + C_SH_2_P2 * sh_b[:, 8] * (xx - yy)\n",
    "            )\n",
    "\n",
    "        if N_sh >= 16:\n",
    "            rgb[:, 0] += (\n",
    "                C_SH_3_N3 * sh_r[:, 9] * (y * (3 * xx - yy))\n",
    "                + C_SH_3_N2 * sh_r[:, 10] * (x * y * z)\n",
    "                + C_SH_3_N1 * sh_r[:, 11] * (y * (5 * zz - 1))\n",
    "                + C_SH_3_0 * sh_r[:, 12] * (z * (5 * zz - 1))\n",
    "                + C_SH_3_P1 * sh_r[:, 13] * (x * (5 * zz - 1))\n",
    "                + C_SH_3_P2 * sh_r[:, 14] * (z * (xx - yy))\n",
    "                + C_SH_3_P3 * sh_r[:, 15] * (x * (xx - 3 * yy))\n",
    "            )\n",
    "\n",
    "            rgb[:, 1] += (\n",
    "                C_SH_3_N3 * sh_g[:, 9] * (y * (3 * xx - yy))\n",
    "                + C_SH_3_N2 * sh_g[:, 10] * (x * y * z)\n",
    "                + C_SH_3_N1 * sh_g[:, 11] * (y * (5 * zz - 1))\n",
    "                + C_SH_3_0 * sh_g[:, 12] * (z * (5 * zz - 1))\n",
    "                + C_SH_3_P1 * sh_g[:, 13] * (x * (5 * zz - 1))\n",
    "                + C_SH_3_P2 * sh_g[:, 14] * (z * (xx - yy))\n",
    "                + C_SH_3_P3 * sh_g[:, 15] * (x * (xx - 3 * yy))\n",
    "            )\n",
    "\n",
    "            rgb[:, 2] += (\n",
    "                C_SH_3_N3 * sh_b[:, 9] * (y * (3 * xx - yy))\n",
    "                + C_SH_3_N2 * sh_b[:, 10] * (x * y * z)\n",
    "                + C_SH_3_N1 * sh_b[:, 11] * (y * (5 * zz - 1))\n",
    "                + C_SH_3_0 * sh_b[:, 12] * (z * (5 * zz - 1))\n",
    "                + C_SH_3_P1 * sh_b[:, 13] * (x * (5 * zz - 1))\n",
    "                + C_SH_3_P2 * sh_b[:, 14] * (z * (xx - yy))\n",
    "                + C_SH_3_P3 * sh_b[:, 15] * (x * (xx - 3 * yy))\n",
    "            )\n",
    "\n",
    "        # apply sigmoid activation to constrain values between 0 and 1\n",
    "        rgb_sigmoid = torch.sigmoid(rgb)\n",
    "\n",
    "        ctx.save_for_backward(sh_coeff, view_dir, rgb_sigmoid)\n",
    "        return rgb_sigmoid\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_rgb_sigmoid):\n",
    "        sh_coeff, view_dir, rgb_sigmoid = ctx.saved_tensors\n",
    "        x = view_dir[:, 0]\n",
    "        y = view_dir[:, 1]\n",
    "        z = view_dir[:, 2]\n",
    "\n",
    "        N = sh_coeff.shape[0]\n",
    "        N_sh = sh_coeff.shape[2]\n",
    "\n",
    "        grad_sh_coeff = torch.zeros_like(sh_coeff)\n",
    "\n",
    "        # backwards of sigmoid\n",
    "        grad_rgb = rgb_sigmoid * (1 - rgb_sigmoid) * grad_rgb_sigmoid\n",
    "\n",
    "        # zero order gradients\n",
    "        for i in range(3):\n",
    "            grad_sh_coeff[:, i, 0] = C_SH_0 * grad_rgb[:, i]\n",
    "\n",
    "        if N_sh >= 4:\n",
    "            sh_1_n1_grad_mult = -1 * C_SH_1 * x\n",
    "            sh_1_0_grad_mult = C_SH_1 * y\n",
    "            sh_1_p1_grad_mult = -1 * C_SH_1 * z\n",
    "\n",
    "            for i in range(3):\n",
    "                grad_sh_coeff[:, i, 1] = sh_1_n1_grad_mult * grad_rgb[:, i]\n",
    "                grad_sh_coeff[:, i, 2] = sh_1_0_grad_mult * grad_rgb[:, i]\n",
    "                grad_sh_coeff[:, i, 3] = sh_1_p1_grad_mult * grad_rgb[:, i]\n",
    "\n",
    "        if N_sh >= 9:\n",
    "            xx = torch.square(x)\n",
    "            yy = torch.square(y)\n",
    "            zz = torch.square(z)\n",
    "\n",
    "            sh_2_n2_grad_mult = C_SH_2_N2 * x * y\n",
    "            sh_2_n1_grad_mult = C_SH_2_N1 * y * z\n",
    "            sh_2_0_grad_mult = C_SH_2_0 * (\n",
    "                3 * zz - torch.ones(z.shape[0], dtype=z.dtype, device=z.device)\n",
    "            )\n",
    "            sh_2_p1_grad_mult = C_SH_2_P1 * x * z\n",
    "            sh_2_p2_grad_mult = C_SH_2_P2 * (xx - yy)\n",
    "\n",
    "            for i in range(3):\n",
    "                grad_sh_coeff[:, i, 4] = sh_2_n2_grad_mult * grad_rgb[:, i]\n",
    "                grad_sh_coeff[:, i, 5] = sh_2_n1_grad_mult * grad_rgb[:, i]\n",
    "                grad_sh_coeff[:, i, 6] = sh_2_0_grad_mult * grad_rgb[:, i]\n",
    "                grad_sh_coeff[:, i, 7] = sh_2_p1_grad_mult * grad_rgb[:, i]\n",
    "                grad_sh_coeff[:, i, 8] = sh_2_p2_grad_mult * grad_rgb[:, i]\n",
    "\n",
    "        if N_sh >= 16:\n",
    "            for i in range(3):\n",
    "                grad_sh_coeff[:, i, 9] = C_SH_3_N3 * (y * (3 * xx - yy)) * grad_rgb[:, i]\n",
    "                grad_sh_coeff[:, i, 10] = C_SH_3_N2 * (x * y * z) * grad_rgb[:, i]\n",
    "                grad_sh_coeff[:, i, 11] = C_SH_3_N1 * (y * (5 * zz - 1)) * grad_rgb[:, i]\n",
    "                grad_sh_coeff[:, i, 12] = C_SH_3_0 * (z * (5 * zz - 1)) * grad_rgb[:, i]\n",
    "                grad_sh_coeff[:, i, 13] = C_SH_3_P1 * (x * (5 * zz - 1)) * grad_rgb[:, i]\n",
    "                grad_sh_coeff[:, i, 14] = C_SH_3_P2 * (z * (xx - yy)) * grad_rgb[:, i]\n",
    "                grad_sh_coeff[:, i, 15] = C_SH_3_P3 * (x * (xx - 3 * yy)) * grad_rgb[:, i]\n",
    "\n",
    "        return grad_sh_coeff, None\n",
    "\n",
    "\n",
    "sh_coeff = torch.rand(10, 3, 16, dtype=torch.float64, requires_grad=True)\n",
    "view_dir = torch.rand(10, 3, dtype=torch.float64, requires_grad=False)\n",
    "rgbs = SphericalHarmonicsToRGB.apply(sh_coeff, view_dir)\n",
    "# print(rgbs)\n",
    "\n",
    "test = gradcheck(SphericalHarmonicsToRGB.apply, (sh_coeff, view_dir))\n",
    "print(test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
